import h5py
import argparse
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from itertools import product
import numpy as np

def parse_args():
    """Parse the command line and return the resulting namespace.

    Three string options are registered, all mandatory: ``--input``,
    ``--output`` and ``--dataset``. They are exposed on the returned
    namespace as ``input``, ``output`` and ``dataset``.
    """
    cli = argparse.ArgumentParser()
    # Every option has the same shape (required, plain string), so register
    # them in one pass; the order matches the usage/help output.
    for flag in ("--input", "--output", "--dataset"):
        cli.add_argument(flag, required=True)
    namespace = cli.parse_args()
    return namespace

def viz():
    """Plot training curves and the last logged weights, then save the figure.

    Command-line options are read through ``parse_args``:

    * ``--dataset``: HDF5 file with a ``ground_truth`` group holding
      ``logL``, ``pies`` and ``sigma2``.
    * ``--input``: HDF5 file with ``train_F`` and a ``theta`` group holding
      ``sigma``, ``pies`` and ``W``.
    * ``--output``: path handed to ``plt.savefig``.

    Figure layout:

    * Right column (cells 2, 4 and 6 of a 3x2 grid): ``train_F``, the columns
      of ``theta/pies`` and ``theta/sigma`` plotted along their first axis
      (presumably the training iteration -- confirm against the writer of
      the input file), each with its ground-truth value (``logL``, ``pies``,
      ``sqrt(sigma2)``) drawn as a dashed horizontal line.
    * Left half (first four columns of both rows of a 2x8 grid): columns
      0..7 of the last entry of ``theta/W``, each reshaped to a 4x4 image.
      All images share one symmetric colour range and one colorbar.

    Returns nothing; the only output is the saved figure.
    """
    args = parse_args()
    out_file = args.output

    # Grid of weight images: 2 rows x 4 columns -> W columns 0..7.
    n_rows, n_cols = 2, 4

    # Read everything into memory up front so both HDF5 files are closed
    # before plotting starts (the handles were previously never closed).
    with h5py.File(args.dataset, "r") as gt_file, \
            h5py.File(args.input, "r") as f:
        gt = gt_file["ground_truth"]
        F_gt = gt["logL"][...]
        pi_gt = gt["pies"][...]
        sigma_gt = np.sqrt(gt["sigma2"][...])

        theta = f["theta"]
        F = f["train_F"][...]
        sigma = theta["sigma"][...]
        pies = theta["pies"][...]
        # Only the last entry along the first axis of W is displayed; read
        # it once instead of hitting the file for every subplot.
        W_last = theta["W"][-1]

    F_ax = plt.subplot(3, 2, 2)
    pi_ax = plt.subplot(3, 2, 4)
    sigma_ax = plt.subplot(3, 2, 6)

    sigma_ax.plot(sigma)
    sigma_ax.plot([0, len(sigma)], [sigma_gt, sigma_gt], '--')

    # One curve per column of `pies`.
    for pi_h in pies.T:
        pi_ax.plot(pi_h)
    pi_ax.plot([0, len(pies)], [pi_gt, pi_gt], '--')

    F_ax.plot(F)
    F_ax.plot([0, len(F)], [F_gt, F_gt], '--')

    axes = []
    # Symmetric colour limits so that zero maps to the centre of "coolwarm".
    vmax = np.abs(W_last).max()
    for r, c in product(range(n_rows), range(n_cols)):
        w_ax = plt.subplot(2, 8, r * 8 + c + 1)
        axes.append(w_ax)
        # Row-major index into the columns of W. The earlier `r*c + c`
        # produced 0,1,2,3,0,2,4,6: column 0 was drawn twice and columns
        # 5 and 7 were never shown.
        w_idx = r * n_cols + c
        # NOTE(review): the 4x4 reshape requires each W column to hold
        # exactly 16 values.
        im = w_ax.imshow(W_last[:, w_idx].reshape(4, 4),
                         vmin=-vmax, vmax=vmax, cmap="coolwarm")
        w_ax.axis("off")
    plt.colorbar(im, ax=axes)
    plt.savefig(out_file)

# Script entry point: build and save the figure only when this file is run
# directly, not when it is imported.
if __name__ == "__main__":
    viz()
